{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "524620c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp tsdataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15392f6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12fa25a4",
   "metadata": {},
   "source": [
    "# PyTorch Dataset/Loader\n",
    "> Torch Dataset for Time Series\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2508f7a9-1433-4ad8-8f2f-0078c6ed6c3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastcore.test import test_eq\n",
    "from nbdev.showdoc import show_doc\n",
    "from neuralforecast.utils import generate_series"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44065066-e72a-431f-938f-1528adef9fe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import warnings\n",
    "from collections.abc import Mapping\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "import utilsforecast.processing as ufp\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from utilsforecast.compat import DataFrame, pl_Series"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "323a7a6e-38c3-496d-8f1e-cad05f643d41",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TimeSeriesLoader(DataLoader):\n",
    "    \"\"\"TimeSeriesLoader DataLoader.\n",
    "    [Source code](https://github.com/Nixtla/neuralforecast1/blob/main/neuralforecast/tsdataset.py).\n",
    "\n",
    "    Small change to PyTorch's Data loader. \n",
    "    Combines a dataset and a sampler, and provides an iterable over the given dataset.\n",
    "\n",
    "    The class `~torch.utils.data.DataLoader` supports both map-style and\n",
    "    iterable-style datasets with single- or multi-process loading, customizing\n",
    "    loading order and optional automatic batching (collation) and memory pinning.    \n",
    "    \n",
    "    **Parameters:**<br>\n",
    "    `batch_size`: (int, optional): how many samples per batch to load (default: 1).<br>\n",
    "    `shuffle`: (bool, optional): set to `True` to have the data reshuffled at every epoch (default: `False`).<br>\n",
    "    `sampler`: (Sampler or Iterable, optional): defines the strategy to draw samples from the dataset.<br>\n",
    "                Can be any `Iterable` with `__len__` implemented. If specified, `shuffle` must not be specified.<br>\n",
    "    \"\"\"\n",
    "    def __init__(self, dataset, **kwargs):\n",
    "        if 'collate_fn' in kwargs:\n",
    "            kwargs.pop('collate_fn')\n",
    "        kwargs_ = {**kwargs, **dict(collate_fn=self._collate_fn)}\n",
    "        DataLoader.__init__(self, dataset=dataset, **kwargs_)\n",
    "    \n",
    "    def _collate_fn(self, batch):\n",
    "        elem = batch[0]\n",
    "        elem_type = type(elem)\n",
    "\n",
    "        if isinstance(elem, torch.Tensor):\n",
    "            out = None\n",
    "            if torch.utils.data.get_worker_info() is not None:\n",
    "                # If we're in a background process, concatenate directly into a\n",
    "                # shared memory tensor to avoid an extra copy\n",
    "                numel = sum(x.numel() for x in batch)\n",
    "                storage = elem.storage()._new_shared(numel, device=elem.device)\n",
    "                out = elem.new(storage).resize_(len(batch), *list(elem.size()))\n",
    "            return torch.stack(batch, 0, out=out)\n",
    "\n",
    "        elif isinstance(elem, Mapping):\n",
    "            if elem['static'] is None:\n",
    "                return dict(temporal=self.collate_fn([d['temporal'] for d in batch]),\n",
    "                            temporal_cols = elem['temporal_cols'],\n",
    "                            y_idx=elem['y_idx'])\n",
    "            \n",
    "            return dict(static=self.collate_fn([d['static'] for d in batch]),\n",
    "                        static_cols = elem['static_cols'],\n",
    "                        temporal=self.collate_fn([d['temporal'] for d in batch]),\n",
    "                        temporal_cols = elem['temporal_cols'],\n",
    "                        y_idx=elem['y_idx'])\n",
    "\n",
    "        raise TypeError(f'Unknown {elem_type}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93e94050-0290-43ad-9a73-c4626bba9541",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TimeSeriesLoader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05687429-c139-44c0-adb9-097c616908cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TimeSeriesDataset(Dataset):\n",
    "\n",
    "    def __init__(self,\n",
    "                 temporal,\n",
    "                 temporal_cols,\n",
    "                 indptr,\n",
    "                 max_size: int,\n",
    "                 min_size: int,\n",
    "                 y_idx: int,\n",
    "                 static=None,\n",
    "                 static_cols=None,\n",
    "                 sorted=False,\n",
    "                ):\n",
    "        super().__init__()\n",
    "        self.temporal = torch.tensor(temporal, dtype=torch.float)\n",
    "        self.temporal_cols = pd.Index(list(temporal_cols))\n",
    "\n",
    "        if static is not None:\n",
    "            self.static = torch.tensor(static, dtype=torch.float)\n",
    "            self.static_cols = static_cols\n",
    "        else:\n",
    "            self.static = static\n",
    "            self.static_cols = static_cols\n",
    "\n",
    "        self.indptr = indptr\n",
    "        self.n_groups = self.indptr.size - 1\n",
    "        self.max_size = max_size\n",
    "        self.min_size = min_size\n",
    "        self.y_idx = y_idx\n",
    "\n",
    "        # Upadated flag. To protect consistency, dataset can only be updated once\n",
    "        self.updated = False\n",
    "        self.sorted = sorted\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        if isinstance(idx, int):\n",
    "            # Parse temporal data and pad its left\n",
    "            temporal = torch.zeros(size=(len(self.temporal_cols), self.max_size),\n",
    "                                   dtype=torch.float32)\n",
    "            ts = self.temporal[self.indptr[idx] : self.indptr[idx + 1], :]\n",
    "            temporal[:len(self.temporal_cols), -len(ts):] = ts.permute(1, 0)\n",
    "\n",
    "            # Add static data if available\n",
    "            static = None if self.static is None else self.static[idx,:]\n",
    "\n",
    "            item = dict(temporal=temporal, temporal_cols=self.temporal_cols,\n",
    "                        static=static, static_cols=self.static_cols,\n",
    "                        y_idx=self.y_idx)\n",
    "\n",
    "            return item\n",
    "        raise ValueError(f'idx must be int, got {type(idx)}')\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.n_groups\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f'TimeSeriesDataset(n_data={self.temporal.shape[0]:,}, n_groups={self.n_groups:,})'\n",
    "\n",
    "    def __eq__(self, other):\n",
    "        if not hasattr(other, 'data') or not hasattr(other, 'indptr'):\n",
    "            return False\n",
    "        return np.allclose(self.data, other.data) and np.array_equal(self.indptr, other.indptr)\n",
    "\n",
    "\n",
    "\n",
    "    def align(self, df: DataFrame, id_col: str, time_col: str, target_col: str) -> 'TimeSeriesDataset':\n",
    "        \"\"\"Build a dataset from `df` whose temporal columns match this dataset's.\n",
    "\n",
    "        Columns of `self.temporal_cols` missing from `df` are filled with NaN,\n",
    "        `available_mask` is always set to 1.0, and the columns are reordered to\n",
    "        `[id_col, time_col, *self.temporal_cols]` before calling `from_df`.\n",
    "        \"\"\"\n",
    "        # Work on the frame returned by `copy_if_pandas` so the assignments below\n",
    "        # are not made on the argument itself (presumably a shallow pandas copy;\n",
    "        # confirm against utilsforecast).\n",
    "        df = ufp.copy_if_pandas(df, deep=False)\n",
    "\n",
    "        expected_cols = self.temporal_cols.copy()\n",
    "        for name in expected_cols:\n",
    "            if name not in df.columns:\n",
    "                # Unknown values for this variable\n",
    "                df = ufp.assign_columns(df, name, np.nan)\n",
    "            if name == 'available_mask':\n",
    "                # The mask is overwritten even when `df` already provides it\n",
    "                df = ufp.assign_columns(df, name, 1.0)\n",
    "\n",
    "        # Same column order as self.temporal_cols\n",
    "        ordered_cols = [id_col, time_col, *expected_cols.tolist()]\n",
    "        df = df[ordered_cols]\n",
    "\n",
    "        # Only the dataset is needed from what `from_df` returns\n",
    "        aligned, *_ = TimeSeriesDataset.from_df(\n",
    "            df=df,\n",
    "            sort_df=self.sorted,\n",
    "            id_col=id_col,\n",
    "            time_col=time_col,\n",
    "            target_col=target_col,\n",
    "        )\n",
    "        return aligned\n",
    "\n",
    "    def append(self, futr_dataset: 'TimeSeriesDataset') -> 'TimeSeriesDataset':\n",
    "        \"\"\"Add future observations to the dataset. Returns a copy\"\"\"\n",
    "        if self.indptr.size != futr_dataset.indptr.size:\n",
    "            raise ValueError('Cannot append `futr_dataset` with different number of groups.')\n",
    "        # Define and fill new temporal with updated information\n",
    "        len_temporal, col_temporal = self.temporal.shape\n",
    "        len_futr = futr_dataset.temporal.shape[0]\n",
    "        new_temporal = torch.empty(size=(len_temporal + len_futr, col_temporal))\n",
    "        new_sizes = np.diff(self.indptr) + np.diff(futr_dataset.indptr)\n",
    "        new_indptr = np.append(0, new_sizes.cumsum()).astype(np.int32)\n",
    "        new_max_size = np.max(new_sizes)\n",
    "\n",
    "        for i in range(self.n_groups):\n",
    "            curr_slice = slice(self.indptr[i], self.indptr[i + 1])\n",
    "            curr_size = curr_slice.stop - curr_slice.start\n",
    "            futr_slice = slice(futr_dataset.indptr[i], futr_dataset.indptr[i + 1])\n",
    "            new_temporal[new_indptr[i] : new_indptr[i] + curr_size] = self.temporal[curr_slice]\n",
    "            new_temporal[new_indptr[i] + curr_size : new_indptr[i + 1]] = futr_dataset.temporal[futr_slice]\n",
    "        \n",
    "        # Define new dataset\n",
    "        updated_dataset = TimeSeriesDataset(temporal=new_temporal,\n",
    "                                            temporal_cols=self.temporal_cols.copy(),\n",
    "                                            indptr=new_indptr,\n",
    "                                            max_size=new_max_size,\n",
    "                                            min_size=self.min_size,\n",
    "                                            static=self.static,\n",
    "                                            y_idx=self.y_idx,\n",
    "                                            static_cols=self.static_cols,\n",
    "                                            sorted=self.sorted)\n",
    "\n",
    "        return updated_dataset\n",
    "\n",
    "    @staticmethod\n",
    "    def update_dataset(dataset, futr_df, id_col='unique_id', time_col='ds', target_col='y'):\n",
    "        futr_dataset = dataset.align(\n",
    "            futr_df, id_col=id_col, time_col=time_col, target_col=target_col\n",
    "        )\n",
    "        return dataset.append(futr_dataset)\n",
    "    \n",
    "    @staticmethod\n",
    "    def trim_dataset(dataset, left_trim: int = 0, right_trim: int = 0):\n",
    "        \"\"\"\n",
    "        Trim temporal information from a dataset.\n",
    "        Returns temporal indexes [t+left:t-right] for all series.\n",
    "        \"\"\"\n",
    "        if dataset.min_size <= left_trim + right_trim:\n",
    "            raise Exception(f'left_trim + right_trim ({left_trim} + {right_trim}) \\\n",
    "                                must be lower than the shorter time series ({dataset.min_size})')\n",
    "\n",
    "        # Define and fill new temporal with trimmed information        \n",
    "        len_temporal, col_temporal = dataset.temporal.shape\n",
    "        total_trim = (left_trim + right_trim) * dataset.n_groups\n",
    "        new_temporal = torch.zeros(size=(len_temporal-total_trim, col_temporal))\n",
    "        new_indptr = [0]\n",
    "\n",
    "        acum = 0\n",
    "        for i in range(dataset.n_groups):\n",
    "            series_length = dataset.indptr[i + 1] - dataset.indptr[i]\n",
    "            new_length = series_length - left_trim - right_trim\n",
    "            new_temporal[acum:(acum+new_length), :] = dataset.temporal[dataset.indptr[i]+left_trim : \\\n",
    "                                                                       dataset.indptr[i + 1]-right_trim, :]\n",
    "            acum += new_length\n",
    "            new_indptr.append(acum)\n",
    "\n",
    "        new_max_size = dataset.max_size-left_trim-right_trim\n",
    "        new_min_size = dataset.min_size-left_trim-right_trim\n",
    "        \n",
    "        # Define new dataset\n",
    "        updated_dataset = TimeSeriesDataset(temporal=new_temporal,\n",
    "                                            temporal_cols= dataset.temporal_cols.copy(),\n",
    "                                            indptr=np.array(new_indptr).astype(np.int32),\n",
    "                                            max_size=new_max_size,\n",
    "                                            min_size=new_min_size,\n",
    "                                            y_idx=dataset.y_idx,\n",
    "                                            static=dataset.static,\n",
    "                                            static_cols=dataset.static_cols,\n",
    "                                            sorted=dataset.sorted)\n",
    "\n",
    "        return updated_dataset\n",
    "\n",
    "    @staticmethod\n",
    "    def from_df(df, static_df=None, sort_df=False, id_col='unique_id', time_col='ds', target_col='y'):\n",
    "        \"\"\"Build a `TimeSeriesDataset` from a long-format dataframe.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `df`: pandas or polars dataframe with `id_col`, `time_col`, `target_col` and optional extra columns.<br>\n",
    "        `static_df`: optional dataframe with static features; `id_col` is excluded from the features.<br>\n",
    "        `sort_df`: bool, sorts `static_df` by `id_col` and is stored as the dataset's `sorted` flag.<br>\n",
    "        `id_col`, `time_col`, `target_col`: names of the identifier, time and target columns.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        Tuple `(dataset, indices, dates, ds)`: the dataset, then the ids and times\n",
    "        returned by `ufp.process_df` (times as a pandas Index or polars Series),\n",
    "        and the `time_col` values of `df`, reordered by the sort indexes from\n",
    "        `ufp.process_df` when it provides them.\n",
    "        \"\"\"\n",
    "        # TODO: protect on equality of static_df + df indexes\n",
    "        id_as_index_msg = 'Passing the id as index is deprecated, please provide it as a column instead.'\n",
    "        if isinstance(df, pd.DataFrame) and df.index.name == id_col:\n",
    "            warnings.warn(id_as_index_msg, FutureWarning)\n",
    "            df = df.reset_index(id_col)\n",
    "        if static_df is not None:\n",
    "            if isinstance(static_df, pd.DataFrame) and static_df.index.name == id_col:\n",
    "                warnings.warn(id_as_index_msg, FutureWarning)\n",
    "            if sort_df:\n",
    "                static_df = ufp.sort(static_df, by=id_col)\n",
    "\n",
    "        ids, times, data, indptr, sort_idxs = ufp.process_df(df, id_col, time_col, target_col)\n",
    "\n",
    "        # processor sets y as the first column\n",
    "        feature_cols = [c for c in df.columns if c not in (id_col, time_col, target_col)]\n",
    "        temporal_cols = pd.Index([target_col] + feature_cols)\n",
    "        temporal = data.astype(np.float32, copy=False)\n",
    "\n",
    "        # Add Available mask efficiently (without adding column to df)\n",
    "        if 'available_mask' not in df.columns:\n",
    "            all_available = np.ones((len(temporal), 1), dtype=np.float32)\n",
    "            temporal = np.append(temporal, all_available, axis=1)\n",
    "            temporal_cols = temporal_cols.append(pd.Index(['available_mask']))\n",
    "\n",
    "        # Static features\n",
    "        if static_df is None:\n",
    "            static = None\n",
    "            static_cols = None\n",
    "        else:\n",
    "            static_names = [col for col in static_df.columns if col != id_col]\n",
    "            static = ufp.to_numpy(static_df[static_names])\n",
    "            static_cols = pd.Index(static_names)\n",
    "\n",
    "        sizes = np.diff(indptr)\n",
    "        dataset = TimeSeriesDataset(\n",
    "            temporal=temporal,\n",
    "            temporal_cols=temporal_cols,\n",
    "            static=static,\n",
    "            static_cols=static_cols,\n",
    "            indptr=indptr,\n",
    "            max_size=max(sizes),\n",
    "            min_size=min(sizes),\n",
    "            sorted=sort_df,\n",
    "            y_idx=0,\n",
    "        )\n",
    "\n",
    "        if isinstance(df, pd.DataFrame):\n",
    "            dates = pd.Index(times, name=time_col)\n",
    "        else:\n",
    "            dates = pl_Series(time_col, times)\n",
    "\n",
    "        ds = df[time_col].to_numpy()\n",
    "        if sort_idxs is not None:\n",
    "            ds = ds[sort_idxs]\n",
    "        return dataset, ids, dates, ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61a818bf-28d2-4561-8036-475f6fe78d0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TimeSeriesDataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52c07552-b6fa-4d10-8792-71743dcdfd1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "\n",
    "# Testing sort_df=True functionality\n",
    "temporal_df = generate_series(n_series=1000, \n",
    "                         n_temporal_features=0, equal_ends=False)\n",
    "sorted_temporal_df = temporal_df.sort_values(['unique_id', 'ds'])\n",
    "unsorted_temporal_df = sorted_temporal_df.sample(frac=1.0)\n",
    "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=unsorted_temporal_df,\n",
    "                                                        sort_df=True)\n",
    "\n",
    "np.testing.assert_allclose(dataset.temporal[:,:-1], \n",
    "                           sorted_temporal_df.drop(columns=['unique_id', 'ds']).values)\n",
    "test_eq(indices, pd.Series(sorted_temporal_df['unique_id'].unique()))\n",
    "test_eq(dates, temporal_df.groupby('unique_id')['ds'].max().values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4dae43c-4d11-4bbc-a431-ac33b004859a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TimeSeriesDataModule(pl.LightningDataModule):\n",
    "    \"\"\"Lightning data module that serves the same `TimeSeriesDataset` to the\n",
    "    train, validation and predict loaders.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `dataset`: TimeSeriesDataset wrapped by every loader.<br>\n",
    "    `batch_size`: int, batch size of the training loader.<br>\n",
    "    `valid_batch_size`: int, batch size of the validation and predict loaders.<br>\n",
    "    `num_workers`: int, forwarded to every loader.<br>\n",
    "    `drop_last`: bool, forwarded to the training and validation loaders.<br>\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            dataset: TimeSeriesDataset,\n",
    "            batch_size=32,\n",
    "            valid_batch_size=1024,\n",
    "            num_workers=0,\n",
    "            drop_last=False\n",
    "        ):\n",
    "        super().__init__()\n",
    "        self.dataset = dataset\n",
    "        self.batch_size = batch_size\n",
    "        self.valid_batch_size = valid_batch_size\n",
    "        self.num_workers = num_workers\n",
    "        self.drop_last = drop_last\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        \"\"\"Shuffled loader that uses `batch_size`.\"\"\"\n",
    "        return TimeSeriesLoader(\n",
    "            self.dataset,\n",
    "            shuffle=True,\n",
    "            batch_size=self.batch_size,\n",
    "            num_workers=self.num_workers,\n",
    "            drop_last=self.drop_last,\n",
    "        )\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        \"\"\"Ordered loader that uses `valid_batch_size`.\"\"\"\n",
    "        return TimeSeriesLoader(\n",
    "            self.dataset,\n",
    "            shuffle=False,\n",
    "            batch_size=self.valid_batch_size,\n",
    "            num_workers=self.num_workers,\n",
    "            drop_last=self.drop_last,\n",
    "        )\n",
    "\n",
    "    def predict_dataloader(self):\n",
    "        \"\"\"Ordered loader that uses `valid_batch_size`; `drop_last` is not\n",
    "        passed, so the loader default applies.\"\"\"\n",
    "        return TimeSeriesLoader(\n",
    "            self.dataset,\n",
    "            shuffle=False,\n",
    "            batch_size=self.valid_batch_size,\n",
    "            num_workers=self.num_workers,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8535a15f-b5cf-4ca1-bfa2-e53a9e8c3bc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TimeSeriesDataModule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b534d29d-eecc-43ba-8468-c23305fa24a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "\n",
    "batch_size = 128\n",
    "data = TimeSeriesDataModule(dataset=dataset, \n",
    "                            batch_size=batch_size, drop_last=True)\n",
    "for batch in data.train_dataloader():\n",
    "    test_eq(batch['temporal'].shape, (batch_size, 2, 500))\n",
    "    test_eq(batch['temporal_cols'], ['y', 'available_mask'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4481272a-ea3a-4b63-8f14-9445d8f41338",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "\n",
    "batch_size = 128\n",
    "n_static_features = 2\n",
    "n_temporal_features = 4\n",
    "temporal_df, static_df = generate_series(n_series=1000,\n",
    "                                         n_static_features=n_static_features,\n",
    "                                         n_temporal_features=n_temporal_features, \n",
    "                                         equal_ends=False)\n",
    "\n",
    "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df,\n",
    "                                                        static_df=static_df,\n",
    "                                                        sort_df=True)\n",
    "data = TimeSeriesDataModule(dataset=dataset,\n",
    "                            batch_size=batch_size, drop_last=True)\n",
    "\n",
    "for batch in data.train_dataloader():\n",
    "    test_eq(batch['temporal'].shape, (batch_size, n_temporal_features + 2, 500))\n",
    "    test_eq(batch['temporal_cols'],\n",
    "            ['y'] + [f'temporal_{i}' for i in range(n_temporal_features)] + ['available_mask'])\n",
    "    \n",
    "    test_eq(batch['static'].shape, (batch_size, n_static_features))\n",
    "    test_eq(batch['static_cols'], [f'static_{i}' for i in range(n_static_features)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "252b59f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "\n",
    "# Build a full frame plus two date-based splits for the `update_dataset` test below\n",
    "temporal_df = generate_series(n_series=2,\n",
    "                              n_temporal_features=2, equal_ends=True)\n",
    "temporal_df = temporal_df.groupby('unique_id').tail(10)\n",
    "temporal_df = temporal_df.reset_index()\n",
    "temporal_full_df = temporal_df.sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
    "temporal_full_df.loc[temporal_full_df.ds > '2001-05-11', ['y', 'temporal_0']] = None\n",
    "\n",
    "split1_df = temporal_full_df.loc[temporal_full_df.ds <= '2001-05-11']\n",
    "split2_df = temporal_full_df.loc[temporal_full_df.ds > '2001-05-11']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6eab7367",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "\n",
    "# Testing available mask\n",
    "temporal_df_w_mask = temporal_df.copy()\n",
    "temporal_df_w_mask['available_mask'] = 1\n",
    "\n",
    "# Mask with all 1's\n",
    "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,\n",
    "                                                        sort_df=True)\n",
    "mask_average = dataset.temporal[:, -1].mean()\n",
    "np.testing.assert_almost_equal(mask_average, 1.0000)\n",
    "\n",
    "# Add 0's to available mask\n",
    "temporal_df_w_mask.loc[temporal_df_w_mask.ds > '2001-05-11', 'available_mask'] = 0\n",
    "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,\n",
    "                                                        sort_df=True)\n",
    "mask_average = dataset.temporal[:, -1].mean()\n",
    "np.testing.assert_almost_equal(mask_average, 0.7000)\n",
    "\n",
    "# Available mask not in last column\n",
    "temporal_df_w_mask = temporal_df_w_mask[['unique_id','ds','y','available_mask', 'temporal_0','temporal_1']]\n",
    "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,\n",
    "                                                        sort_df=True)\n",
    "mask_average = dataset.temporal[:, 1].mean()\n",
    "np.testing.assert_almost_equal(mask_average, 0.7000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0d23f1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the future_df wrangling of the `update_dataset` method.\n",
    "# The dataset built from the full dataframe must match the one obtained by\n",
    "# building it from the first split and then appending the second split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39f999c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "\n",
    "# FULL DATASET\n",
    "dataset_full, indices_full, dates_full, ds_full = TimeSeriesDataset.from_df(df=temporal_full_df,\n",
    "                                                                            sort_df=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30f927e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "\n",
    "# SPLIT_1 DATASET\n",
    "dataset_1, indices_1, dates_1, ds_1 = TimeSeriesDataset.from_df(df=split1_df,\n",
    "                                                                sort_df=False)\n",
    "dataset_1 = dataset_1.update_dataset(dataset_1, split2_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "468a6879",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "\n",
    "np.testing.assert_almost_equal(dataset_full.temporal.numpy(), dataset_1.temporal.numpy())\n",
    "test_eq(dataset_full.max_size, dataset_1.max_size)\n",
    "test_eq(dataset_full.indptr, dataset_1.indptr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "556f852c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "\n",
    "# Testing trim_dataset functionality\n",
    "n_static_features = 0\n",
    "n_temporal_features = 2\n",
    "temporal_df = generate_series(n_series=100,\n",
    "                              min_length=50,\n",
    "                              max_length=100,\n",
    "                              n_static_features=n_static_features,\n",
    "                              n_temporal_features=n_temporal_features, \n",
    "                              equal_ends=False)\n",
    "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df,\n",
    "                                                        static_df=static_df,\n",
    "                                                        sort_df=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db7b1a51",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "left_trim = 10\n",
    "right_trim = 20\n",
    "dataset_trimmed = dataset.trim_dataset(dataset, left_trim=left_trim, right_trim=right_trim)\n",
    "\n",
    "np.testing.assert_almost_equal(dataset.temporal[dataset.indptr[50]+left_trim:dataset.indptr[51]-right_trim].numpy(),\n",
    "                               dataset_trimmed.temporal[dataset_trimmed.indptr[50]:dataset_trimmed.indptr[51]].numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "624a3fbb-cb78-4440-a645-54699fd82660",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| polars\n",
    "import polars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1bdd479-b4c7-4a40-93eb-2b7c9b969a80",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| polars\n",
    "temporal_df2 = temporal_df.copy()\n",
    "for col in ('unique_id', 'temporal_0', 'temporal_1'):\n",
    "    temporal_df2[col] = temporal_df2[col].cat.codes\n",
    "temporal_pl = polars.from_pandas(temporal_df2).sample(fraction=1.0)\n",
    "static_pl = polars.from_pandas(static_df.assign(unique_id=lambda df: df['unique_id'].astype('int64')))\n",
    "dataset_pl, indices_pl, dates_pl, ds_pl = TimeSeriesDataset.from_df(df=temporal_pl, static_df=static_df, sort_df=True)\n",
    "for attr in ('static_cols', 'temporal_cols', 'min_size', 'max_size', 'n_groups'):\n",
    "    test_eq(getattr(dataset, attr), getattr(dataset_pl, attr))\n",
    "torch.testing.assert_allclose(dataset.temporal, dataset_pl.temporal)\n",
    "torch.testing.assert_allclose(dataset.static, dataset_pl.static)\n",
    "pd.testing.assert_series_equal(indices.astype('int64'), indices_pl.to_pandas().astype('int64'))\n",
    "pd.testing.assert_index_equal(dates, pd.Index(dates_pl, name='ds'))\n",
    "np.testing.assert_array_equal(ds, ds_pl)\n",
    "np.testing.assert_array_equal(dataset.indptr, dataset_pl.indptr)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
